import csv
import json
from typing import List

from copy import deepcopy
import random

import spacy
import en_core_web_md
from tqdm import tqdm
import pickle
from termcolor import cprint

# Seed the global RNG so random.sample / random.shuffle below draw the same
# random numbers on every run.
random.seed(42)

# Load every dialogue row of the OpenDialKG CSV as a (row, unique_id) pair.
# unique_id is a running counter that starts at 1000000 for the first data row.
# newline='' is what the csv module documentation requires for file objects;
# without it, newlines embedded inside quoted fields are not interpreted
# correctly.
with open("../dataset/opendialkg/opendialkg.csv", newline='') as csvfile:
    reader = csv.reader(csvfile)
    next(reader, None)  # skip the header row (no-op on an empty file)
    dataset = [
        (rows, unique_id)
        for unique_id, rows in enumerate(reader, start=1000000)
    ]

print(f"Dataset Size: {len(dataset) }")

# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# codebooks that come from a trusted source.
with open("../dataset/opendialkg/entity_codebook.pkl", 'rb') as f:
    entity_codebook = pickle.load(f)
# Inverse mapping (code -> entity) used by map_code.
reverse_entity_codebook = {v:k for k, v in entity_codebook.items()}
with open("../dataset/opendialkg/relation_codebook.pkl", 'rb') as f:
    relation_codebook = pickle.load(f)

# spaCy pipeline used by find_entity to detect named entities in dialogue text.
nlp = en_core_web_md.load()

# Map Entity to corresponding code
def map_entity(entity):
    """Return the code stored for ``entity`` in ``entity_codebook``.

    The lookup key is ``entity.lower()``.

    Returns:
        The codebook value, or ``None`` when the lower-cased name is not in
        the codebook or ``entity`` has no ``lower`` method (e.g. ``None``).
    """
    try:
        return entity_codebook[entity.lower()]
    except (KeyError, AttributeError):
        # Unknown entity or non-string input: report "no code" rather than
        # raising. Only these two errors are swallowed, so unrelated failures
        # (and KeyboardInterrupt) are no longer hidden.
        return None

# Map Code to corresponding entity
def map_code(code):
    """Return the entity stored for ``code`` in ``reverse_entity_codebook``.

    Returns:
        The entity, or ``None`` when ``code`` is not a key of the reverse
        codebook or cannot be used as a dictionary key (unhashable).
    """
    try:
        return reverse_entity_codebook[code]
    except (KeyError, TypeError):
        # Unknown or unhashable code: report "no entity" rather than raising.
        # Only these two errors are swallowed, so unrelated failures (and
        # KeyboardInterrupt) are no longer hidden.
        return None

def find_entity(message, entities):
    """Collect the entity strings mentioned in ``message``.

    Two sources are combined:

    1. every candidate in ``entities`` that occurs in ``message`` as a
       case-insensitive substring (added with the candidate's own spelling);
    2. every named entity the spaCy pipeline ``nlp`` detects in ``message``
       (added with its surface text, ``ent.text``).

    Character offsets (start_char / end_char) are no longer needed, so only
    the strings themselves are returned.

    Args:
        message: dialogue text to search.
        entities: iterable of candidate entity names.

    Returns:
        A set of entity strings.
    """
    # TODO (carried over from the original notes): look up aliases of the
    # spaCy mentions in the DB; that step is not implemented here.

    # Lower-case the message once instead of once per candidate.
    lowered_message = message.lower()
    found_entities = {
        entity for entity in entities if entity.lower() in lowered_message
    }
    found_entities.update(ent.text for ent in nlp(message).ents)
    return found_entities

def make_symmetric_graph(triplets):
    """Return a copy of ``triplets`` extended with the missing reverse edges.

    For every ``[head, relation, tail]`` the reverse edge
    ``[tail, inverse_relation, head]`` is appended unless an edge with the
    same content is already present. The inverse of a relation is obtained by
    toggling a leading ``'~'``: ``'r' -> '~r'`` and ``'~r' -> 'r'``.

    Args:
        triplets: list of 3-element sequences of strings; it is not modified.

    Returns:
        A new list: a deep copy of the input followed by the added reverse
        edges (each a list), in input order.
    """
    new_triplets = deepcopy(triplets)
    # Set of already-present edges, so each membership test is O(1) instead
    # of a scan over the growing list.
    seen = {tuple(triplet) for triplet in new_triplets}
    for h, r, t in triplets:
        # Only a *leading* '~' marks an inverse relation. Testing "'~' in r"
        # would chop the first character off any relation that merely
        # contains '~' somewhere else.
        new_relation = r[1:] if r.startswith('~') else '~' + r
        reverse_triplet = (t, new_relation, h)
        if reverse_triplet not in seen:
            seen.add(reverse_triplet)
            new_triplets.append(list(reverse_triplet))
    return new_triplets

# Read the knowledge-graph triples file into memory, one raw line (trailing
# newline included) per list entry; build_database splits each line on tabs.
with open("../dataset/opendialkg/opendialkg_triples.txt", 'r') as f:
    entire_triplets = list(f)
print(f"# Entire Triples: {len(entire_triplets)}")

def build_database(key="head,relation", triplets=None):
    """Index knowledge-graph triples by head or by (head, relation).

    Each input line is stripped and split on tabs; lines with fewer than
    three fields are skipped.

    Args:
        key: ``"head,relation"`` to key the index by ``"<head>\\t<relation>"``,
            or ``"head"`` to key it by the head entity alone.
        triplets: iterable of tab-separated ``head<TAB>relation<TAB>tail``
            lines. Defaults to the module-level ``entire_triplets``.

    Returns:
        A plain ``dict`` mapping each key to a set of
        ``(head, relation, tail)`` tuples.

    Raises:
        ValueError: if ``key`` is not one of the two supported modes
            (previously this surfaced as an UnboundLocalError mid-loop).
    """
    if key not in ("head,relation", "head"):
        raise ValueError(
            f"Unsupported key {key!r}; expected 'head,relation' or 'head'"
        )
    if triplets is None:
        triplets = entire_triplets

    # A plain dict (not a defaultdict) on purpose: callers rely on missing
    # keys raising KeyError / failing ``in`` tests.
    database = {}
    for line in triplets:
        fields = line.strip().split('\t')
        if len(fields) < 3:
            continue
        head, relation, tail = fields

        _id = f"{head}\t{relation}" if key == "head,relation" else head
        database.setdefault(_id, set()).add((head, relation, tail))
    return database

def make_entire_graph(entities):
    """Build the symmetric one-hop graph around ``entities``.

    Collects every triple stored in ``head_database`` under one of the given
    entities (entities without an entry are ignored) and adds the reverse of
    each collected triple. The inverse of a relation is obtained by toggling
    a leading ``'~'``: ``'r' -> '~r'`` and ``'~r' -> 'r'``.

    Args:
        entities: iterable of head-entity keys.

    Returns:
        A new set of ``(head, relation, tail)`` tuples containing the
        collected triples and their reverses. The sets stored in
        ``head_database`` are not modified.
    """
    set_of_triplets = set()
    for entity in entities:
        set_of_triplets.update(head_database.get(entity, ()))

    # Only a *leading* '~' marks an inverse relation. Testing "'~' in r" would
    # chop the first character off any relation that merely contains '~'.
    reverse_triplets = {
        (t, r[1:] if r.startswith('~') else '~' + r, h)
        for h, r, t in set_of_triplets
    }
    # Set union already removes reverse edges that were present to begin with.
    return set_of_triplets | reverse_triplets

# Triple index keyed by "<head>\t<relation>"; augment_qa uses it to look up
# alternative answers for a question.
database = build_database(key="head,relation")
# Triple index keyed by head entity only; make_entire_graph uses it to gather
# an entity's outgoing edges.
head_database = build_database("head")

class QAData:
    """One extractive-QA example built from a dialogue.

    Attributes are stored exactly as passed to the constructor:
    ``context`` (text), ``question`` (sequence of strings, joined with spaces
    when printed), ``answer``, the answer span ``start_char`` / ``end_char``,
    and the identifiers ``episode_id``, ``turn_id`` and ``unique_id``.
    """

    def __init__(self, context, question, answer, start_char, end_char,
                 episode_id, turn_id, unique_id):
        # The example itself.
        self.context, self.question, self.answer = context, question, answer
        # Answer span inside the context.
        self.start_char, self.end_char = start_char, end_char
        # Provenance of the example.
        self.episode_id, self.turn_id, self.unique_id = (
            episode_id, turn_id, unique_id
        )

    def __str__(self):
        """Return a three-line summary: context, question and answer."""
        return (
            f"Context: {self.context}\n"
            f"Question: {' '.join(self.question)}\n"
            f"Answer: {self.answer}"
        )

def wrap_qa(context: str, question: List[str], answer: str, episode_id, turn_id, unique_id):
    """Locate ``answer`` in ``context`` and wrap everything in a ``QAData``.

    The answer is searched case-insensitively and the first match is used.
    ``start_char`` / ``end_char`` are inclusive character offsets. The located
    slice of ``context`` is compared (case-insensitively) with ``answer``; on
    a mismatch, which includes the case where the answer is not found, both
    strings are printed and ``None`` is returned.
    """
    begin = context.lower().find(answer.lower())
    stop = begin + len(answer)  # exclusive end of the located span
    located = context[begin:stop]

    if located.lower() == answer.lower():
        return QAData(context, question, answer, begin, stop - 1,
                      episode_id, turn_id, unique_id)

    # Span does not match the answer: report both and drop the example.
    print(answer)
    print(located)
    return None

# Replace answer of context
def replace_answer(qa_data, new_answer):
    """Return a copy of ``qa_data`` whose answer span is replaced by ``new_answer``.

    The characters ``start_char..end_char`` (inclusive) of the context are
    swapped for ``new_answer``. The span is then re-located as the first
    case-insensitive occurrence of ``new_answer`` in the new context. If the
    located slice does not match ``new_answer``, both strings are printed and
    ``None`` is returned. Question and identifiers are carried over unchanged.
    """
    old_context = qa_data.context
    before = old_context[:qa_data.start_char]
    after = old_context[qa_data.end_char + 1:]
    new_context = ''.join((before, new_answer, after))

    begin = new_context.lower().find(new_answer.lower())
    stop = begin + len(new_answer)  # exclusive end of the located span
    located = new_context[begin:stop]

    if located.lower() == new_answer.lower():
        return QAData(new_context, qa_data.question, new_answer, begin, stop - 1,
                      qa_data.episode_id, qa_data.turn_id, qa_data.unique_id)

    # Span does not match the new answer: report both and drop the example.
    print(new_answer)
    print(located)
    return None

def augment_qa(qa_dataset, max_candidates=10):
    """Augment QA examples by swapping in other answers to the same question.

    For every example whose question, joined with a tab, is a key of
    ``database``, each tail stored under that key (other than the example's
    own answer) is substituted into the context via ``replace_answer``.
    Substitutions for which ``replace_answer`` returns ``None`` are dropped.

    Args:
        qa_dataset: list of QA examples exposing ``question`` and ``answer``.
        max_candidates: upper bound on the number of candidate answers tried
            per example, to avoid too many QA data from entities with
            "high degree". Larger candidate pools are randomly sub-sampled.

    Returns:
        A new list: the original examples followed by the augmented ones.
    """
    augmented_dataset = []
    for data in qa_dataset:
        candidate_triplets = database.get('\t'.join(data.question))
        if candidate_triplets is None:
            continue
        # The candidates come from a set of string tuples, whose iteration
        # order can change from one interpreter run to the next. Sorting
        # first makes random.sample (and the output order) reproducible under
        # a fixed random seed.
        candidate_answers = sorted(t[-1] for t in candidate_triplets)
        if len(candidate_answers) > max_candidates:
            candidate_answers = random.sample(candidate_answers, max_candidates)
        for ans in candidate_answers:
            if ans == data.answer:
                continue
            _data = replace_answer(data, ans)
            if _data is not None:
                augmented_dataset.append(_data)
    return qa_dataset + augmented_dataset

def sequential_print(something_list):
    """Print an evenly spaced sample of ``something_list`` for inspection.

    Every ``len(something_list) // 10``-th item is printed, each followed by
    a blank line. For lists shorter than ten items the step is clamped to 1,
    so every item is printed instead of failing with a modulo by zero.
    """
    freq = max(1, len(something_list) // 10)
    for i, item in enumerate(something_list):
        if i % freq == 0:
            print(item)
            print()

def to_jsonl_format(data):
    """Convert a QA example into a plain dict ready for ``json.dumps``.

    The dict holds the example's identifiers, context, question, answer and
    answer span, copied attribute-by-attribute from ``data``.
    """
    # Order matters: it is the key order of each emitted JSON line.
    exported_fields = (
        'episode_id',
        'turn_id',
        'context',
        'question',
        'answer',
        'start_char',
        'end_char',
        'unique_id',
    )
    return {name: getattr(data, name) for name in exported_fields}

"""
Overall Blueprint
1. Take Metadata and extract triplets
"""
def _collect_path_entities(turns):
    """Return the head and tail entities of every triplet in the turns' metadata paths."""
    entities = set()
    for turn in turns:
        metadata = turn.get('metadata', {})
        if 'path' in metadata:
            # metadata['path'][1] holds the list of triplets of this turn.
            for triplet in metadata['path'][1]:
                entities.add(triplet[0])
                entities.add(triplet[-1])
    return entities


def _write_jsonl(filename_out, records):
    """Write ``records`` to ``filename_out`` as one JSON document per line."""
    with open(filename_out, 'w') as outfile:
        for record in records:
            outfile.write(json.dumps(record) + '\n')


def _preprocess(dataset):
    """Turn dialogue episodes into QA examples and write train/test JSONL files.

    For every episode:
      1. the entities of all metadata paths are gathered and a symmetric
         one-hop graph is built around them (``make_entire_graph``);
      2. for each assistant message, every graph triple whose head is
         mentioned in the preceding user turns (``find_entity``) and whose
         tail occurs in the assistant message yields one QA example:
         context = history + message, question = (head, relation),
         answer = tail.

    The examples are then augmented (``augment_qa``), converted to dicts,
    shuffled, split 90/10 and written to ``dataset/train_entire.jsonl`` and
    ``dataset/test_entire.jsonl``.

    Args:
        dataset: iterable of ``(rows, unique_id)`` pairs, where ``rows[0]`` is
            a JSON-encoded list of turn dicts.
    """
    total_graph_size = 0
    total_num = 0

    qa_dataset = []
    for episode_id, (rows, unique_id) in enumerate(tqdm(dataset, desc="Preprocessing...")):
        turn_id = 0
        history = []  # user messages since the last assistant message

        pp_rows = json.loads(rows[0])
        # Gather entities first (from metadata:path).
        entities = _collect_path_entities(pp_rows)

        # The graph depends only on the episode's entities, so build it (and
        # a head -> triplets index) once per episode rather than once per
        # assistant turn.
        sym_triplets = make_entire_graph(entities)
        triplets_by_head = {}
        for triplet in sym_triplets:
            triplets_by_head.setdefault(triplet[0], []).append(triplet)

        for row in pp_rows:
            if 'message' not in row:
                continue
            message = row['message']
            if row['sender'] != 'assistant':
                history.append(message)
                continue

            # Assistant turn: wrap up and make the data.
            total_graph_size += len(sym_triplets)
            total_num += 1

            dialog_text = '\n'.join(history)
            # history + message becomes the context.
            context = '\n'.join(history + [message])
            message_lower = message.lower()

            for entity1 in find_entity(dialog_text, list(entities)):
                for triplet in triplets_by_head.get(entity1, ()):
                    entity2 = triplet[2]
                    if entity2.lower() not in message_lower:
                        continue
                    # (triplet[0], triplet[1]) becomes the question and
                    # entity2 becomes the answer.
                    _qa_data = wrap_qa(context, triplet[0:2], entity2, episode_id, turn_id, unique_id)
                    if _qa_data is not None:
                        qa_dataset.append(_qa_data)
            turn_id += 1
            # WHY initialize history? Sometimes the former "gold response"
            # contains the answer of the question, which hinders correct
            # evaluation of the generated response.
            history = []

        if episode_id % 1000 == 0:
            # Guard against episodes seen so far having no assistant turn,
            # which used to raise ZeroDivisionError.
            average = total_graph_size / total_num if total_num else 0.0
            print(f"Average Num: {average}")

    # QA templates are built.
    print(f"Original Size: {len(qa_dataset)}")

    # Generate a training dataset with varied answers from the QA templates.
    qa_dataset = augment_qa(qa_dataset)

    qa_dataset = [to_jsonl_format(data) for data in qa_dataset]
    print(f"Final Size: {len(qa_dataset)}")

    random.shuffle(qa_dataset)

    train_ratio = 0.9
    train_len = int(len(qa_dataset) * train_ratio)

    _write_jsonl("dataset/train_entire.jsonl", qa_dataset[:train_len])
    _write_jsonl("dataset/test_entire.jsonl", qa_dataset[train_len:])

# Run the full preprocessing pipeline on the loaded dialogues. This call sits
# at module level, so it executes whenever the file is run or imported.
_preprocess(dataset)